from stable_baselines3.common.callbacks import BaseCallback
import utils
import numpy as np

class ContinualEvalCallback(BaseCallback):
    """
    Callback that records the environment backlog during training.

    After every environment step the callback reads the ``info`` dict of the
    first environment (index 0 of ``self.locals['infos']``) and, whenever the
    reported ``'time'`` is a multiple of ``log_freq``, stores the reported
    ``'backlog'`` together with that time. The recorded series can be
    retrieved with :meth:`get_stats`.

    :param log_freq: A sample is recorded when ``time % log_freq == 0``.
    :param plot_freq: When ``plot`` is true, the running average of the
        recorded backlog is printed when ``time % plot_freq == 0``.
    :param state_transformation: Stored on the instance; not used by the
        active code of this class.
    :param fname: Stored on the instance; not used by the active code of
        this class.
    :param plot: Enables the periodic console output (running-average backlog
        and normalised advantage statistics).
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    :param use_emw: Stored on the instance; not used by the active code of
        this class.
    :param alpha: Stored on the instance; not used by the active code of
        this class.
    """

    # Advantage statistics are printed at the end of a rollout only when the
    # environment time is a multiple of this value.
    # NOTE(review): presumably chosen to line up with the rollout length - confirm.
    ADV_PRINT_INTERVAL = 10240

    def __init__(self, log_freq=10, plot_freq=1000,
                 state_transformation=None, fname=None, plot=False, verbose=0,
                 use_emw=False, alpha=0.1):
        super().__init__(verbose)
        self.log_freq = log_freq
        self.plot_freq = plot_freq
        self.plot = plot
        self.state_transformation = state_transformation
        self.fname = fname
        # Recorded backlog samples and the environment times they belong to.
        self.backlog = []
        self.time = []
        self.use_emw = use_emw
        self.alpha = alpha
        # Initialised empty; nothing in this class appends to it.
        self.avg_backlog = []

    def _on_step(self) -> bool:
        """
        Record the backlog and optionally print its running average.

        Called by the model after each call to ``env.step()``. Only the info
        dict of the first environment is inspected.

        :return: (bool) Always ``True``, so training is never aborted by this callback.
        """
        info = self.locals['infos'][0]
        time = info['time']
        if time % self.log_freq == 0:
            self.backlog.append(info['backlog'])
            self.time.append(time)

        # Skip the report while nothing has been recorded yet: this happens
        # when a multiple of ``plot_freq`` is reached before the first
        # multiple of ``log_freq``, and ``self.time[-1]`` would raise IndexError.
        if self.plot and (time % self.plot_freq == 0) and self.backlog:
            denom = np.arange(1, len(self.backlog) + 1)
            # Running mean: element i is the mean of the first i + 1 samples.
            avg_backlog = np.divide(np.cumsum(self.backlog), denom)
            print(self.time[-1], avg_backlog)

        # NOTE: periodic heat-map plotting through ``utils.plot_heatmap`` used
        # to be triggered here (every 100000 time units) and is currently disabled.
        return True

    def _on_rollout_end(self) -> None:
        """
        Print the mean of the min-max normalised absolute advantages.

        Invoked through ``BaseCallback.on_rollout_end()``. Output is produced
        only when ``plot`` is true and the environment time of the first
        environment is a multiple of ``ADV_PRINT_INTERVAL``.
        """
        time = self.locals['infos'][0]['time']
        if self.plot and (time % self.ADV_PRINT_INTERVAL == 0):
            advantages = np.abs(self.locals['rollout_buffer'].advantages)
            lowest = advantages.min()
            # The small epsilon avoids a division by zero when all values are equal.
            advantages = (advantages - lowest) / (advantages.max() - lowest + 1e-8)
            print('adv: {}'.format(np.mean(advantages)))

    def get_stats(self):
        """
        Return the recorded series.

        :return: (dict) ``{'backlog': [...], 'time': [...]}``; the lists are
            the callback's own objects, not copies.
        """
        stats = {
            'backlog': self.backlog,
            'time': self.time
        }
        return stats
